#include "mtcnn.hpp"
#include "comm_lib.hpp"


void nms_boxes(std::vector<face_box>& input, float threshold, int type, std::vector<face_box>&output)
{

	std::sort(input.begin(),input.end(),
			[](const face_box& a, const face_box&b) {
			return a.score > b.score;  
			});

	int box_num=input.size();

	std::vector<int> merged(box_num,0);

	for(int i=0;i<box_num;i++)
	{ 
		if(merged[i])
			continue;

		output.push_back(input[i]);

		float h0=input[i].y1-input[i].y0+1;
		float w0=input[i].x1-input[i].x0+1;

		float area0=h0*w0;


		for(int j=i+1;j<box_num;j++)
		{
			if(merged[j]) 
				continue;

			float inner_x0=std::max(input[i].x0,input[j].x0);
			float inner_y0=std::max(input[i].y0,input[j].y0);

			float inner_x1=std::min(input[i].x1,input[j].x1);
			float inner_y1=std::min(input[i].y1,input[j].y1);

			float inner_h=inner_y1-inner_y0+1;
			float inner_w=inner_x1-inner_x0+1;


			if(inner_h<=0 || inner_w<=0)
				continue;

			float inner_area=inner_h*inner_w;

			float h1=input[j].y1-input[j].y0+1;
			float w1=input[j].x1-input[j].x0+1;

			float area1=h1*w1;

			float score;

			if(type == NMS_UNION)
			{
				score=inner_area/(area0+area1-inner_area);
			}
			else
			{
				score=inner_area/std::min(area0,area1);
			}


			if(score>threshold)
				merged[j]=1;
		}


	}


}

void regress_boxes(std::vector<face_box>& rects)
{
	for(unsigned int i=0;i<rects.size();i++)
	{
		face_box& box=rects[i];

		float h=box.y1-box.y0+1;
		float w=box.x1-box.x0+1;

		box.x0=box.x0+w*box.regress[0];
		box.y0=box.y0+h*box.regress[1];
		box.x1=box.x1+w*box.regress[2];
		box.y1=box.y1+h*box.regress[3];
	}    

}

void square_boxes(std::vector<face_box>& rects)
{

	for(unsigned int i=0;i<rects.size();i++)
	{
		float h=rects[i].y1-rects[i].y0+1;
		float w=rects[i].x1-rects[i].x0+1;

		float l=std::max(h,w);

		rects[i].x0=rects[i].x0+(w-l)*0.5;
		rects[i].y0=rects[i].y0+(h-l)*0.5;
		rects[i].x1=rects[i].x0+l-1;
		rects[i].y1=rects[i].y0+l-1;
	}
}

void padding(int img_h, int img_w, std::vector<face_box>& rects)
{
	for(unsigned int i=0; i<rects.size();i++)
	{
		rects[i].px0=std::max(rects[i].x0,1.0f);
		rects[i].py0=std::max(rects[i].y0,1.0f);
		rects[i].px1=std::min(rects[i].x1,(float)img_w);
		rects[i].py1=std::min(rects[i].y1,(float)img_h);
	}
} 


// Post-processing chain for one batch of candidate boxes:
//   1. nms_boxes      - union-type suppression with a 0.7 overlap threshold;
//                       survivors are appended to `rects` (and `input` is
//                       reordered by score as a side effect)
//   2. regress_boxes  - apply the regression offsets
//   3. square_boxes   - make every box square
//   4. padding        - compute the image-clamped px*/py* corners
void process_boxes(std::vector<face_box>& input, int img_h, int img_w, std::vector<face_box>& rects)
{
	const float overlap_threshold = 0.7;

	nms_boxes(input, overlap_threshold, NMS_UNION, rects);
	regress_boxes(rects);
	square_boxes(rects);
	padding(img_h, img_w, rects);
}


// Converts a confidence map plus a regression map into candidate boxes.
//
// confidence_data : two consecutive planes of confidence_size/2 values each;
//                   only the second plane is read as the score
//                   (presumably the "face" probability - confirm with caller).
// reg_data        : four consecutive planes of feature_h*feature_w values.
// scale           : factor the image was resized by; coordinates are divided
//                   by it to map back to the original image.
// threshold       : minimum score for a cell to produce a box.
// transposed      : when true, regression planes 0/1 and 2/3 are swapped
//                   while being copied into box.regress.
//
// Every cell whose score is >= threshold appends one box to `output`; the
// box covers a 12-pixel window placed every 2 pixels in the scaled image.
void generate_bounding_box(const float * confidence_data, int confidence_size,
		const float * reg_data, float scale, float threshold, 
		int feature_h, int feature_w, std::vector<face_box>&  output, bool transposed)
{
	const int stride = 2;
	const int cell_size = 12;

	// Skip the first confidence plane; scores live in the second one.
	const int cell_count = confidence_size / 2;
	const float* scores = confidence_data + cell_count;

	// Maps a coordinate in the scaled image back to the original image,
	// truncating to a whole pixel.
	auto to_original = [scale](int scaled_coord) -> float {
		return (int)(scaled_coord / scale);
	};

	for (int cell = 0; cell < cell_count; ++cell)
	{
		const float cell_score = scores[cell];
		if (!(cell_score >= threshold))
			continue;

		const int row = cell / feature_w;
		const int col = cell - feature_w * row;

		face_box box;

		box.x0 = to_original(col * stride + 1);
		box.y0 = to_original(row * stride + 1);
		box.x1 = to_original(col * stride + cell_size);
		box.y1 = to_original(row * stride + cell_size);

		box.score = cell_score;

		const int plane_offset = row * feature_w + col;
		const int plane_size = feature_w * feature_h;

		// Plane k goes to regress[k], or to regress[k ^ 1] when transposed
		// (0<->1 and 2<->3 swapped).
		for (int plane = 0; plane < 4; ++plane)
		{
			const int slot = transposed ? (plane ^ 1) : plane;
			box.regress[slot] = reg_data[plane_offset + plane * plane_size];
		}

		output.push_back(box);
	}
}


// Appends three single-channel float matrices (CV_32FC1, height x width) to
// `input_channels`. Each matrix is a header over a consecutive
// width*height slice of `input_data`; no pixel data is copied here, so the
// buffer must hold at least 3*width*height floats and outlive the matrices.
void set_input_buffer(std::vector<cv::Mat>& input_channels,
		float* input_data, const int height, const int width) 
{
	const int channel_count = 3;
	const int plane_size = width * height;

	float* plane = input_data;
	for (int c = 0; c < channel_count; ++c, plane += plane_size)
	{
		input_channels.push_back(cv::Mat(height, width, CV_32FC1, plane));
	}
}


// Builds the list of image-pyramid levels.
//
// height, width : size of the source image.
// min_size      : smallest object size of interest; the base scale is
//                 12 / min_size.
// factor        : per-level shrink factor; must be below 1 so the pyramid
//                 terminates.
// list          : receives one scale_window per level (appended), holding
//                 the scale and the ceil'ed scaled height/width.
//
// Levels are produced while the scaled shorter side is still >= 12.
// If min_size <= 0, or factor is >= 1 (or NaN), nothing is appended:
// those inputs would otherwise divide by zero or loop forever.
void  cal_pyramid_list(int height, int width, int min_size, float factor,std::vector<scale_window>& list)
{
	// Reject parameters that cannot yield a finite pyramid.
	// `!(factor < 1.0f)` is also true for NaN.
	if (min_size <= 0 || !(factor < 1.0f))
		return;

	const double m = 12.0 / min_size;

	int min_side = std::min(height, width);
	min_side = min_side * m;

	double cur_scale = 1.0;

	while (min_side >= 12)
	{
		const double scale = m * cur_scale;

		cur_scale = cur_scale * factor;
		// NOTE(review): min_side is an int, so it is truncated at every
		// level; kept as-is so existing pyramids do not change.
		min_side *= factor;

		const int hs = std::ceil(height * scale);
		const int ws = std::ceil(width * scale);

		scale_window win;
		win.h = hs;
		win.w = ws;
		win.scale = scale;
		list.push_back(win);
	}
}

void cal_landmark(std::vector<face_box>& box_list)
{
	for(unsigned int i=0;i<box_list.size();i++)
	{
		face_box& box=box_list[i];

		float w=box.x1-box.x0+1;
		float h=box.y1-box.y0+1;

		for(int j=0;j<5;j++)
		{
			box.landmark.x[j]=box.x0+w*box.landmark.x[j]-1;
			box.landmark.y[j]=box.y0+h*box.landmark.y[j]-1;
		}

	}
}

void set_box_bound(std::vector<face_box>& box_list, int img_h, int img_w)
{
	for(unsigned int i=0; i<box_list.size();i++)
	{
		face_box& box=box_list[i];

		box.x0=std::max((int)(box.x0),1);
		box.y0=std::max((int)(box.y0),1);
		box.x1=std::min((int)(box.x1),img_w-1);
		box.y1=std::min((int)(box.y1),img_h-1);

	}

}
